{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "cc20c872",
   "metadata": {},
   "source": [
    "## autodl 环境准备"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "543cc8f5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# !pip install datasets transformers scikit-learn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "b4f2ffcb",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# os.environ['no_proxy'] = 'localhost,127.0.0.1,modelscope.com,aliyuncs.com,tencentyun.com,wisemodel.cn'\n",
    "# os.environ['NO_PROXY'] = 'localhost,127.0.0.1,modelscope.com,aliyuncs.com,tencentyun.com,wisemodel.cn'\n",
    "\n",
    "# os.environ['http_proxy'] = 'http://100.72.64.19:12798'\n",
    "# os.environ['HTTP_PROXY'] = 'http://100.72.64.19:12798'\n",
    "\n",
    "# os.environ['https_proxy'] = 'http://100.72.64.19:12798'\n",
    "# os.environ['HTTPS_PROXY'] = 'http://100.72.64.19:12798'\n",
    "os.environ['HF_ENDPOINT']=\"https://hf-mirror.com\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "43492e99",
   "metadata": {},
   "source": [
    "## 初始"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "50f4ec5a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch import optim\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "\n",
    "from transformers import AutoTokenizer\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "beadb6db",
   "metadata": {},
   "outputs": [],
   "source": [
    "BATCH_SIZE = 256\n",
    "NUM_EPOCHS = 6\n",
    "MAX_SEQ_LEN = 50\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "learning_rate=5e-4\n",
    "weight_decay=1e-5\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6e5a4e3f",
   "metadata": {},
   "source": [
    "## 数据集"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "d201a8a1",
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_dataset\n",
    "from datasets.download import DownloadConfig\n",
    "\n",
    "\n",
    "full_dataset = load_dataset(\"wmt/wmt17\", \"zh-en\", split=\"train\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "2570020f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "从 总共 1141860 个数据中选择了 1141860 个\n",
      "----------------------------------------------------------------------------------------------------\n",
      "UNSEQ number: E-00206\n",
      "联合国索赔序号: E-00206\n",
      "`Come, then!\n",
      "曼内特到哪儿去了？\n",
      "Rule 57\n",
      "第 57 条\n"
     ]
    }
   ],
   "source": [
    "import re\n",
    "\n",
    "def is_valid_trans(x) -> bool:\n",
    "    en = x[\"translation\"][\"en\"]\n",
    "    zh = x[\"translation\"][\"zh\"]\n",
    "    if not en or not zh:\n",
    "        return False\n",
    "    if len(en) < 3 or len(zh) < 3:\n",
    "        return False\n",
    "    # 核心过滤条件\n",
    "    if len(en) > 100 or len(zh) > 100:\n",
    "        return False\n",
    "    ratio = len(en) / len(zh)\n",
    "    if ratio < 0.5 or ratio > 2:\n",
    "        return False\n",
    "    if not re.search(r'[\\u4e00-\\u9fff]', zh):\n",
    "        return False\n",
    "    return True\n",
    "\n",
    "full_dataset = full_dataset.filter(is_valid_trans, num_proc=10)\n",
    "\n",
    "dataset = full_dataset.select(range(len(full_dataset)))\n",
    "dataset = dataset.shuffle(seed=20250709)\n",
    "\n",
    "print(f\"从 总共 {len(full_dataset)} 个数据中选择了 {len(dataset)} 个\")\n",
    "\n",
    "sample = dataset.select(range(3))\n",
    "print(\"-\"*100)\n",
    "for i in sample:\n",
    "    print(i[\"translation\"][\"en\"])\n",
    "    print(i[\"translation\"][\"zh\"])\n",
    "    # print(\"-\"*100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "49223f55",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "class WmtDataset(Dataset):\n",
    "    def __init__(self, hf_dataset, tokenizer_en, tokenizer_zh, max_length=100):\n",
    "        self.dataset = hf_dataset\n",
    "        self.tokenizer_en = tokenizer_en\n",
    "        self.tokenizer_zh = tokenizer_zh\n",
    "        self.max_length = max_length\n",
    "    \n",
    "    def __len__(self):\n",
    "        return len(self.dataset)\n",
    "    \n",
    "    def __getitem__(self, idx):\n",
    "        item = self.dataset[idx]\n",
    "        en_text = item[\"translation\"][\"en\"]\n",
    "        zh_text = item[\"translation\"][\"zh\"]\n",
    "\n",
    "        en_tokens = self.tokenizer_en(en_text, \n",
    "                                        max_length=self.max_length, \n",
    "                                        padding='max_length', \n",
    "                                        truncation=True, \n",
    "                                        return_tensors='pt')\n",
    "            \n",
    "        zh_tokens = self.tokenizer_zh(zh_text, \n",
    "                                        max_length=self.max_length, \n",
    "                                        padding='max_length', \n",
    "                                        truncation=True, \n",
    "                                        return_tensors='pt')\n",
    "            \n",
    "        return en_tokens['input_ids'].squeeze(), zh_tokens['input_ids'].squeeze(),\n",
    "        \n",
    "\n",
    "# bert-base-multilingual-cased 有11万单词，太大了\n",
    "encoder_tokenizer = AutoTokenizer.from_pretrained(\"bert-base-uncased\")\n",
    "decoder_tokenizer = AutoTokenizer.from_pretrained(\"bert-base-chinese\")\n",
    "encoder_vocab_size = encoder_tokenizer.vocab_size\n",
    "decoder_vocab_size = decoder_tokenizer.vocab_size\n",
    "\n",
    "total = len(dataset)\n",
    "split_idx = int(len(dataset)* 0.95)\n",
    "train_hf_dataset = dataset.select(range(0, split_idx))\n",
    "test_hf_dataset = dataset.select(range(split_idx, total))\n",
    "train_data = WmtDataset(train_hf_dataset, tokenizer_en=encoder_tokenizer, tokenizer_zh=decoder_tokenizer)\n",
    "val_data = WmtDataset(test_hf_dataset, tokenizer_en=encoder_tokenizer, tokenizer_zh=decoder_tokenizer)\n",
    "\n",
    "\n",
    "train_iter = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=12,pin_memory=True if torch.cuda.is_available() else False)\n",
    "val_iter = DataLoader(val_data, batch_size=BATCH_SIZE, shuffle=False, num_workers=12,pin_memory=True if torch.cuda.is_available() else False)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "34113774",
   "metadata": {},
   "source": [
    "## 模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "7c0d5c57",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "class Encoder(nn.Module):\n",
    "    def __init__(self, vocab_size, embed_size=512, hidden_size=1024, num_layers=2, dropout=0.1):\n",
    "        super(Encoder, self).__init__()\n",
    "        \n",
    "        self.vocab_size = vocab_size\n",
    "        self.embed_size = embed_size\n",
    "        self.hidden_size = hidden_size\n",
    "        self.num_layers = num_layers\n",
    "        \n",
    "        # Embedding layer to convert token IDs to dense vectors\n",
    "        self.embedding = nn.Embedding(vocab_size, embed_size, padding_idx=0)\n",
    "        \n",
    "        # GRU layer for processing sequences\n",
    "        self.rnn = nn.GRU(embed_size, hidden_size, num_layers, \n",
    "                        batch_first=True, dropout=dropout, bidirectional=False)\n",
    "        \n",
    "    def forward(self, input_seq):\n",
    "        # Convert token IDs to embeddings\n",
    "        embedded = self.embedding(input_seq)  # [batch_size, seq_len, embed_size]\n",
    "        \n",
    "        # Pass through GRU\n",
    "        outputs, hidden = self.rnn(embedded)\n",
    "        \n",
    "        # outputs: [batch_size, seq_len, hidden_size]\n",
    "        # hidden: [num_layers, batch_size, hidden_size] \n",
    "        \n",
    "        return outputs, hidden\n",
    "\n",
    "class Decoder(nn.Module):\n",
    "    def __init__(self, vocab_size, embed_size=512, hidden_size=1024, num_layers=2, dropout=0.1):\n",
    "        super(Decoder, self).__init__()\n",
    "        \n",
    "        self.vocab_size = vocab_size\n",
    "        self.embed_size = embed_size\n",
    "        self.hidden_size = hidden_size\n",
    "        self.num_layers = num_layers\n",
    "        \n",
    "        # Embedding layer for target tokens\n",
    "        self.embedding = nn.Embedding(vocab_size, embed_size, padding_idx=0)\n",
    "        \n",
    "        # GRU layer for generating sequences\n",
    "        self.rnn = nn.GRU(embed_size, hidden_size, num_layers, \n",
    "                        batch_first=True, dropout=dropout, bidirectional=False)\n",
    "        \n",
    "        # Output projection layer to vocabulary\n",
    "        self.output_projection = nn.Linear(hidden_size, vocab_size)\n",
    "        \n",
    "    def forward(self, input_token, hidden):\n",
    "        embedded = self.embedding(input_token)  # [batch_size, seq_len, embed_size]\n",
    "        \n",
    "        # Process entire sequence in parallel\n",
    "        gru_out, final_hidden = self.rnn(embedded, hidden)\n",
    "        # gru_out: [batch_size, seq_len, hidden_size]\n",
    "        \n",
    "        # Project to vocabulary for all timesteps\n",
    "        outputs = self.output_projection(gru_out)  # [batch_size, seq_len, vocab_size]\n",
    "        \n",
    "        return outputs, final_hidden\n",
    "\n",
    "\n",
    "class Seq2Seq(nn.Module):\n",
    "    def __init__(self, encoder_vocab_size, decoder_vocab_size, embed_size=512, \n",
    "                hidden_size=1024, num_layers=2, dropout=0.1):\n",
    "        super(Seq2Seq, self).__init__()\n",
    "        \n",
    "        self.encoder_vocab_size = encoder_vocab_size\n",
    "        self.decoder_vocab_size = decoder_vocab_size\n",
    "        self.embed_size = embed_size\n",
    "        self.hidden_size = hidden_size\n",
    "        self.num_layers = num_layers\n",
    "\n",
    "        # Initialize encoder and decoder\n",
    "        self.encoder = Encoder(encoder_vocab_size, embed_size, hidden_size, num_layers, dropout)\n",
    "        self.decoder = Decoder(decoder_vocab_size, embed_size, hidden_size, num_layers, dropout)\n",
    "        \n",
    "    def forward(self, source_seq, target_seq):\n",
    "        \n",
    "        _, hidden = self.encoder(source_seq)\n",
    "        \n",
    "        outputs, _ = self.decoder(target_seq, hidden)\n",
    "        \n",
    "        return outputs\n",
    "    \n",
    "    def generate(self, source_seq, beam_size=1, max_length=100, start_token=101, end_token=102):\n",
    "        if beam_size > 1:\n",
    "            return self.beam_search(source_seq, beam_size, max_length, start_token, end_token)\n",
    "        self.eval()\n",
    "        batch_size = source_seq.size(0)\n",
    "        \n",
    "        with torch.no_grad():\n",
    "            # Encode source sequence\n",
    "            _, hidden = self.encoder(source_seq)\n",
    "            \n",
    "            # Initialize with start token\n",
    "            decoder_input = torch.full((batch_size, 1), start_token, dtype=torch.long).to(source_seq.device)\n",
    "            \n",
    "            # Store generated tokens\n",
    "            generated_tokens = []\n",
    "            \n",
    "            for _ in range(max_length):\n",
    "                # (bs, 1, vocab), (bs, l, hidden)\n",
    "                output, hidden = self.decoder(decoder_input, hidden)\n",
    "                \n",
    "                # Get the token with highest probability\n",
    "                next_token = output.argmax(dim=2)\n",
    "                generated_tokens.append(next_token)\n",
    "                \n",
    "                # Use predicted token as next input\n",
    "                decoder_input = next_token\n",
    "                \n",
    "                # Stop if all sequences generated EOS token\n",
    "                if torch.all(next_token == end_token):\n",
    "                    break\n",
    "            \n",
    "            # Concatenate all generated tokens\n",
    "            generated_seq = torch.cat(generated_tokens, dim=1)\n",
    "            \n",
    "        return generated_seq\n",
    "    def beam_search(self, source_seq, beam_size=3, max_length=100, start_token=101, end_token=102):\n",
    "        self.eval()\n",
    "        assert 1 == source_seq.size(0) # 不支持并行\n",
    "        \n",
    "        with torch.no_grad():\n",
    "            _, hidden = self.encoder(source_seq)\n",
    "            # (seq, prob, hidden)\n",
    "            device = next(self.parameters()).device\n",
    "            initial_seq = torch.tensor([[start_token]], dtype=torch.long).to(device)\n",
    "            beams = [(initial_seq, 0.0, hidden)]\n",
    "            \n",
    "            for _ in range(max_length):\n",
    "                new_beams = []\n",
    "                for seq, log_p, h in beams:\n",
    "                    if seq[0, -1].item() == end_token:\n",
    "                        new_beams.append((seq, log_p, h))\n",
    "                        continue\n",
    "                    dec_out, nh = self.decoder(seq[:,-1:], h)\n",
    "                    log_probs = F.log_softmax(dec_out.squeeze(1), dim=-1) # (1, vocab_size)\n",
    "                    top_log_probs, top_idx = torch.topk(log_probs, beam_size) # (1, beam_size)\n",
    "                    \n",
    "                    for i in range(beam_size):\n",
    "                        token_id = top_idx[0, i].unsqueeze(0).unsqueeze(0)\n",
    "                        token_log_p = top_log_probs[0, i].item()\n",
    "                        new_seq = torch.cat([seq, token_id], dim=1)\n",
    "                        new_log_p = log_p + token_log_p\n",
    "                        new_beams.append((new_seq, new_log_p, nh))\n",
    "                beams = sorted(new_beams, key=lambda x:x[1], reverse=True)[:beam_size] # 取beam size个    \n",
    "                if all(seq[0, -1].item() == end_token for seq, _, _ in beams):\n",
    "                    break\n",
    "            return beams[0][0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "0f89b8c7",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "总参数: 13,423,496\n",
      "可训练参数: 13,423,496\n",
      "占用空间: 51.21 MB\n"
     ]
    }
   ],
   "source": [
    "model = Seq2Seq(encoder_vocab_size, decoder_vocab_size, embed_size=128, hidden_size=256, num_layers=2, dropout=0.2).to(device)\n",
    "\n",
    "total_params = sum(p.numel() for p in model.parameters())\n",
    "trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
    "\n",
    "print(f\"总参数: {total_params:,}\")\n",
    "print(f\"可训练参数: {trainable_params:,}\")\n",
    "print(f\"占用空间: {total_params * 4 / 1024 / 1024:.2f} MB\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "82157179",
   "metadata": {},
   "outputs": [],
   "source": [
    "def translate(\n",
    "    model: Seq2Seq,\n",
    "    english_text: str,\n",
    "    encoder_tokenizer,\n",
    "    decoder_tokenizer,\n",
    "    beam_size=3,\n",
    "    max_length: int = MAX_SEQ_LEN,\n",
    "):\n",
    "    \"\"\"\n",
    "    使用指定的Seq2Seq模型将英文文本翻译为中文。\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "\n",
    "    input_token = encoder_tokenizer(\n",
    "        english_text,\n",
    "        max_length=max_length,\n",
    "        padding=\"max_length\",\n",
    "        truncation=True,\n",
    "        return_tensors=\"pt\",\n",
    "    )\n",
    "    source_seq = input_token[\"input_ids\"].to(device)\n",
    "\n",
    "    start_token_id = decoder_tokenizer.cls_token_id\n",
    "    end_token_id = decoder_tokenizer.sep_token_id\n",
    "\n",
    "    generated_ids = model.generate(\n",
    "        source_seq,\n",
    "        beam_size=beam_size,\n",
    "        max_length=max_length,\n",
    "        start_token=start_token_id,\n",
    "        end_token=end_token_id,\n",
    "    )\n",
    "    generated_text = decoder_tokenizer.decode(\n",
    "        generated_ids[0], skip_special_tokens=True\n",
    "    )\n",
    "\n",
    "    # BERT中文tokenizer解码后可能在字与字之间添加空格\n",
    "    return generated_text.replace(\" \", \"\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "765c02b7",
   "metadata": {},
   "outputs": [],
   "source": [
    "def test_translate(model_):\n",
    "\n",
    "    speed_test_cases = [\n",
    "        \"Hello.\",\n",
    "        \"How are you doing today?\",\n",
    "        \"I really enjoy reading books and learning about different cultures around the world.\",\n",
    "        \"The quick brown fox jumps over the lazy dog while the sun is shining brightly in the clear blue sky.\"\n",
    "    ]\n",
    "\n",
    "    for sentence in speed_test_cases:\n",
    "        print(sentence,\"-->\", translate(model_, sentence, encoder_tokenizer, decoder_tokenizer))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e1eeb0ea",
   "metadata": {},
   "source": [
    "## 训练"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "745bc03d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 半精度训练\n",
    "from collections import defaultdict\n",
    "from tqdm.notebook import tqdm\n",
    "\n",
    "from torch import autocast, GradScaler\n",
    "scaler = GradScaler()\n",
    "\n",
    "\n",
    "criterion = nn.CrossEntropyLoss(ignore_index=0, reduction='mean')\n",
    "optimizer = optim.Adam(model.parameters(), \n",
    "                            lr=learning_rate, \n",
    "                            weight_decay=weight_decay)\n",
    "scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2)\n",
    "\n",
    "history = defaultdict(list)\n",
    "\n",
    "def train_epoch():\n",
    "    model.train()\n",
    "    train_loss = 0\n",
    "    progress_bar = tqdm(train_iter, desc=\"Train\")\n",
    "    \n",
    "    for x, y in progress_bar:\n",
    "        x = x.to(device)\n",
    "        y = y.to(device)\n",
    "        \n",
    "        with autocast(device, dtype=torch.bfloat16):\n",
    "            pred = model(x, y[:,:-1])\n",
    "            loss = criterion(pred.permute(0, 2, 1), y[:, 1:])\n",
    "        scaler.scale(loss).backward()\n",
    "        scaler.step(optimizer)\n",
    "        scaler.update()\n",
    "            \n",
    "        # pred = model(x, y[:,:-1])\n",
    "        # loss = criterion(pred.permute(0, 2, 1), y[:, 1:])\n",
    "        # loss.backward()\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "        train_loss += loss.item()\n",
    "        progress_bar.set_postfix({'Loss': f'{loss.item():.4f}'})\n",
    "    \n",
    "    return train_loss / len(train_iter)\n",
    "\n",
    "def val_epoch():\n",
    "    model.eval()\n",
    "    val_loss = 0\n",
    "    \n",
    "    with torch.no_grad(): \n",
    "        progress_bar = tqdm(val_iter, desc=\"Validation\")\n",
    "        \n",
    "        for x, y in progress_bar:\n",
    "            x = x.to(device)\n",
    "            y = y.to(device)        \n",
    "            \n",
    "            pred = model(x, y[:,:-1]) \n",
    "            loss = criterion(pred.permute(0, 2, 1), y[:, 1:]) \n",
    "            val_loss += loss.item()\n",
    "            progress_bar.set_postfix({'Loss': f'{loss.item():.4f}'})\n",
    "\n",
    "    return val_loss / len(val_iter)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bd819e5c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "1ae46b8a1376436fa63353b4c53b4403",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Train:   0%|          | 0/4238 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "10dcd7c380ad4b2a94e4462c327ca9ed",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation:   0%|          | 0/224 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "EP1/6, Train Loss:3.5681314326615308, Val Loss:2.6202752175075665\n",
      "Saved best model with validation loss: 2.6203\n",
      "Hello. --> 那么。\n",
      "How are you doing today? --> 你们的时候是什么?\n",
      "I really enjoy reading books and learning about different cultures around the world. --> 如果我们的时候,我们必须在这个问题,我们必须在这里的时候,我们可以发现了.\n",
      "The quick brown fox jumps over the lazy dog while the sun is shining brightly in the clear blue sky. --> 在这里的时候，。\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "b72c05f074ef401882c83b5d3ffb0789",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Train:   0%|          | 0/4238 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "939b3f61503e4a4a991ce3feeea11c8b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation:   0%|          | 0/224 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "EP2/6, Train Loss:2.4659975874024553, Val Loss:2.258140161101307\n",
      "Saved best model with validation loss: 2.2581\n",
      "Hello. --> 你好吗？\n",
      "How are you doing today? --> 你能到这个时候了吗？\n",
      "I really enjoy reading books and learning about different cultures around the world. --> 我的朋友是一个很好的,但是,我们可以在这里,我们就是这样的.\n",
      "The quick brown fox jumps over the lazy dog while the sun is shining brightly in the clear blue sky. --> 这个时候,我们的时候,我们看到了一个世界的时候,我们的时候,我们就是这样的.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "019342b36b9e4a269048784b086ea6e0",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Train:   0%|          | 0/4238 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "7ed54c910ba24150abb21044c2f9976b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation:   0%|          | 0/224 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "EP3/6, Train Loss:2.2122994122948496, Val Loss:2.092355420014688\n",
      "Saved best model with validation loss: 2.0924\n",
      "Hello. --> 你好吗？\n",
      "How are you doing today? --> 你今天这个问题是什么？\n",
      "I really enjoy reading books and learning about different cultures around the world. --> 我的朋友们可以看到我们的世界上最好的时候，我可以看到一个很好的。\n",
      "The quick brown fox jumps over the lazy dog while the sun is shining brightly in the clear blue sky. --> 这个时候都看到了一个星期的时候,我们看到了一个星期的时候,他们就是一种不同的.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "9dcbb5a39dbb47fa9d8fa9501b1e114a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Train:   0%|          | 0/4238 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "8e422a4d142b4c959b5202cf12865bbd",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation:   0%|          | 0/224 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "EP4/6, Train Loss:2.0743875816893387, Val Loss:1.9932746546609061\n",
      "Saved best model with validation loss: 1.9933\n",
      "Hello. --> 你好吗？\n",
      "How are you doing today? --> 你们今天就这样做了吗?\n",
      "I really enjoy reading books and learning about different cultures around the world. --> 如果我喜欢这种情况，我可以看到一个很好的时候。\n",
      "The quick brown fox jumps over the lazy dog while the sun is shining brightly in the clear blue sky. --> 天然后,我们的眼睛里看到了一个星期里的眼睛里,看上去看到了一天.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "493e6814e506446f822d36ef396daa9e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Train:   0%|          | 0/4238 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "7ecea81c9e364e2cb6830f7b531ea54b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation:   0%|          | 0/224 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "EP5/6, Train Loss:1.9847149303804428, Val Loss:1.9258415305188723\n",
      "Saved best model with validation loss: 1.9258\n",
      "Hello. --> 你好吗？\n",
      "How are you doing today? --> 你们今天就是这个问题吗？\n",
      "I really enjoy reading books and learning about different cultures around the world. --> 我喜欢这种情况。\n",
      "The quick brown fox jumps over the lazy dog while the sun is shining brightly in the clear blue sky. --> 天天上午6:10时间里看到了两天的时候了.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "9c40f9418a314b58a2edaa5354cf826d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Train:   0%|          | 0/4238 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "best_val_loss = np.inf\n",
    "for epoch in range(NUM_EPOCHS):\n",
    "    train_loss = train_epoch()\n",
    "    history['train_loss'].append(train_loss)\n",
    "\n",
    "    val_loss = val_epoch()\n",
    "    history['val_loss'].append(val_loss)\n",
    "    scheduler.step(val_loss)\n",
    "    current_lr = optimizer.param_groups[0]['lr']\n",
    "\n",
    "    print(\"EP{}/{}, Train Loss:{}, Val Loss:{}\".format(epoch+1, NUM_EPOCHS, history['train_loss'][-1], history['val_loss'][-1]))\n",
    "\n",
    "    if val_loss < best_val_loss:\n",
    "        best_val_loss = val_loss\n",
    "        torch.save(model.state_dict(), \"best_seq2seq.pth\")\n",
    "        print(f\"Saved best model with validation loss: {best_val_loss:.4f}\")\n",
    "    \n",
    "    test_translate(model)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6450214d",
   "metadata": {},
   "source": [
    "## 翻译"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "60fccd6c",
   "metadata": {},
   "source": [
    "## 可视化"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9a74f846",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "defaultdict(<class 'list'>, {'train_loss': [3.627725058200507, 2.4953166972133785, 2.3628535910618087, 2.363077530072173, 2.3629240704741665, 2.3631410310717333], 'val_loss': [2.642954061073916, 2.322091503334897, 2.3220915155751363, 2.322091528879745, 2.322091529411929, 2.322091516107321]})\n"
     ]
    }
   ],
   "source": [
    "print(history)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5b05ed04",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "不是吗？\n",
      "你们的人道吗?\n",
      "\n",
      "《公民权利公约》第一条第一项第二十一条第一项建议。\n"
     ]
    }
   ],
   "source": [
    "print(translate(model, \"Hello\", encoder_tokenizer, decoder_tokenizer))\n",
    "print(translate(model, \"How old are you?\", encoder_tokenizer, decoder_tokenizer))\n",
    "print(translate(model, \"Beautiful is better than ugly.\", encoder_tokenizer, decoder_tokenizer))\n",
    "print(translate(model, \"Special cases aren't special enough to break the rules.\", encoder_tokenizer, decoder_tokenizer))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
